import os.path

import matplotlib.pyplot as plt
from osgeo import gdal
import numpy as np


def read_tiff(lst, evi):
    """
    Read band 1 of two rasters, sizing the second array like the first.

    Parameters
    ----------
    lst : str
        Path of the reference raster. Band 1 is read at its native size.
    evi : str
        Path of the second raster. Band 1 is read into a buffer that has the
        same number of rows and columns as the array read from ``lst``.

    Returns
    -------
    tuple
        ``(data_lst, data_evi)``: the two band-1 arrays.

    Raises
    ------
    RuntimeError
        If GDAL cannot open one of the paths.
    """
    ds_lst = gdal.Open(lst)
    # Without gdal.UseExceptions(), a failed open returns None instead of
    # raising; fail here with the offending path rather than with an
    # AttributeError on the next call.
    if ds_lst is None:
        raise RuntimeError("GDAL could not open raster: {}".format(lst))
    data_lst = ds_lst.GetRasterBand(1).ReadAsArray()

    ds_evi = gdal.Open(evi)
    if ds_evi is None:
        raise RuntimeError("GDAL could not open raster: {}".format(evi))
    rows, cols = data_lst.shape
    # Request an output buffer of the reference size so both arrays can be
    # compared pixel by pixel.
    data_evi = ds_evi.GetRasterBand(1).ReadAsArray(buf_xsize=cols, buf_ysize=rows)
    return data_lst, data_evi


def _normalise_squared(data_vi):
    """Min-max scale *data_vi* to [0, 1] (NaNs ignored) and square the result."""
    data_vi = np.asarray(data_vi)
    vi_min = np.nanmin(data_vi)
    vi_range = np.nanmax(data_vi) - vi_min
    # A constant or all-NaN index would turn every pixel into NaN (0/0) and
    # make the later regression fail with an unrelated error.
    if not np.isfinite(vi_range) or vi_range == 0:
        raise ValueError("data_vi must contain at least two distinct finite values")
    return ((data_vi - vi_min) / vi_range) ** 2


def _interval_max_lst(vi, lst, left, right, n_sub_intervals):
    """Return the representative maximum LST for pixels with left <= vi < right.

    *vi* and *lst* are 1-D, NaN-free and aligned. Returns ``None`` when the
    interval contains no pixel.
    """
    edges = np.linspace(left, right, n_sub_intervals + 1, endpoint=True)
    print(edges)

    # Narrow to this interval once; the sub-interval masks below then only
    # scan these pixels instead of the whole raster.
    in_interval = (vi >= left) & (vi < right)
    vi_in = vi[in_interval]
    lst_in = lst[in_interval]

    # Step 2: highest temperature of every non-empty sub-interval.
    sub_max = []
    for low, high in zip(edges[:-1], edges[1:]):
        members = lst_in[(vi_in >= low) & (vi_in < high)]
        if members.size:
            sub_max.append(members.max())

    # Steps 3-4: while more than two sub-intervals remain, drop the leading one
    # if its maximum lies below (mean - standard deviation) of the rest.
    while len(sub_max) > 2 and sub_max[0] < np.mean(sub_max) - np.std(sub_max):
        sub_max = sub_max[1:]

    if not sub_max:
        return None
    # Step 7: the mean of the remaining sub-interval maxima represents the interval.
    return np.mean(sub_max)


def computed_TSVI(data_vi, data_lst, *, n_intervals=100, n_sub_intervals=10, show=True):
    """
    Fit the dry edge (upper LST envelope) of a VI-LST scatter and plot it.

    ``data_vi`` is min-max normalised and squared, split into ``n_intervals``
    equal-width intervals, each split into ``n_sub_intervals`` sub-intervals.
    A representative maximum temperature is derived per interval and a
    straight line ``lst = a * vi + b`` is fitted through the
    (interval left edge, maximum temperature) pairs.

    Progress and intermediate results are printed to stdout.

    Parameters
    ----------
    data_vi : array_like
        Index values (NaN = no data). Must have the same shape as ``data_lst``.
    data_lst : array_like
        Temperature values (NaN = no data).
    n_intervals : int, optional
        Number of intervals (the original hard-coded ``M``); default 100.
    n_sub_intervals : int, optional
        Number of sub-intervals per interval (the original ``N``); default 10.
    show : bool, optional
        When true (default) draw the scatter, the fitted dry edge and a
        horizontal line at ``a + b`` with matplotlib and call ``plt.show()``.
        The title and axis labels are fixed to "ALBEDO"/"LST".

    Returns
    -------
    tuple of float
        ``(a, b)``, slope and intercept of the fitted line.

    Raises
    ------
    ValueError
        If the shapes differ, ``data_vi`` has no spread, or fewer than two
        points are available for the regression.
    """
    data_vi = _normalise_squared(data_vi)
    data_lst = np.asarray(data_lst)
    if data_vi.shape != data_lst.shape:
        raise ValueError("data_vi and data_lst must have the same shape")

    # Step 1: interval edges over the normalised index range.
    vi_interval = np.linspace(np.nanmin(data_vi), np.nanmax(data_vi), n_intervals + 1, endpoint=True)

    # Drop no-data pixels once, so a NaN temperature can neither become a
    # sub-interval "maximum" nor be counted as a sub-interval.
    valid = ~(np.isnan(data_vi) | np.isnan(data_lst))
    vi_valid = data_vi[valid]
    lst_valid = data_lst[valid]

    lst_max_list = []
    evi_interval_list = []
    for i in range(n_intervals):
        print("计算区间{}".format(i+1))
        interval_max = _interval_max_lst(
            vi_valid, lst_valid, vi_interval[i], vi_interval[i + 1], n_sub_intervals
        )
        if interval_max is None:
            continue
        lst_max_list.append(interval_max)
        evi_interval_list.append(vi_interval[i])
    print(lst_max_list)
    print(evi_interval_list)

    # Step 8: linear regression of maximum temperature against the index.
    # NOTE(review): two quirks of the previous implementation are kept on
    # purpose so the fitted numbers do not change silently:
    #   * the slice ends at -1, so the last populated interval is excluded
    #     from the fit; confirm whether that is intended;
    #   * the step-9 refinement loop (discard intervals, refit) required
    #     ``i > M / 2`` on its first pass and therefore always stopped after
    #     one fit. Only that single fit is performed here.
    fit_vi = np.array(evi_interval_list[:-1])
    fit_lst = np.array(lst_max_list[:-1])
    if fit_vi.size < 2:
        raise ValueError("not enough populated intervals to fit the dry edge")
    a, b = np.polyfit(fit_vi, fit_lst, 1)
    print(a, b)

    if show:
        # Step 10: draw the fitted line and a horizontal line at its value
        # for vi == 1 (a + b).
        lst_max = a * vi_interval + b
        lst_min = np.full_like(lst_max, a + b)
        plt.title("ALBEDO-LST")
        plt.scatter(data_vi.flatten(), data_lst.flatten(), c='r', marker='o')
        plt.plot(vi_interval, lst_max, c='b', label='max')
        plt.plot(vi_interval, lst_min, c='g', label='min')
        plt.xlabel("ALBEDO")
        plt.ylabel("LST")
        plt.legend()
        plt.show()

    return float(a), float(b)


if __name__ == '__main__':
    # Disabled batch variant, kept for reference: pairs every LST raster with
    # the VI raster at the same list position and applies scale factors
    # (0.02 / 0.0001) to all scenes except "2013207".
    # lst = r"G:\test\NDVI_LST\LST"
    # vi = r"G:\test\NDVI_LST\NDVI"
    # lst_list = [os.path.join(lst, f) for f in os.listdir(lst)]
    # vi_list = [os.path.join(vi, f) for f in os.listdir(vi)]
    # for i in range(len(lst_list)):
    #     data_lst, data_vi = read_tiff(lst_list[i], vi_list[i])
    #     if lst_list[i].split("\\")[-1].split(".")[0] == "2013207":
    #         data_lst = data_lst.astype(np.float32)
    #         data_vi = data_vi.astype(np.float32)
    #     else:
    #         data_lst = data_lst.astype(np.float32) * 0.02
    #         data_vi = data_vi.astype(np.float32) * 0.0001
    #     data_lst[data_lst == -9999] = np.nan
    #     data_vi[data_vi == -9999] = np.nan
    #     data_vi[data_vi < 0] = np.nan
    #     computed_TSVI(data_vi, data_lst)

    def _load_raster(path, like=None):
        """Read *path* as float32 with -9999 replaced by NaN.

        When *like* is given, the data is read into a buffer with the same
        number of rows and columns as *like*.
        """
        dataset = gdal.Open(path)
        # gdal.Open returns None on failure unless GDAL exceptions are enabled.
        if dataset is None:
            raise RuntimeError("GDAL could not open raster: {}".format(path))
        if like is None:
            data = dataset.ReadAsArray()
        else:
            data = dataset.ReadAsArray(buf_xsize=like.shape[1], buf_ysize=like.shape[0])
        data = data.astype(np.float32)
        data[data == -9999] = np.nan
        return data

    lst = r"G:\test\3D_TSVI\2013207_lst_day.tif"
    # NOTE(review): the active albedo file is dated 2013217 while LST and EVI
    # are 2013207; confirm the mismatch is intended.
    # albedo = r"G:\test\3D_TSVI\2013207_albedo.tif"
    albedo = r"G:\test\3D_TSVI\albedo\2013217.tif"
    evi = r"G:\test\3D_TSVI\2013207_evi_500m.tif"

    # The LST raster defines the grid size; the other rasters are read onto it.
    lst_data = _load_raster(lst)
    # evi_data is only consumed by the alternative calls listed below.
    evi_data = _load_raster(evi, like=lst_data)
    albedo_data = _load_raster(albedo, like=lst_data)

    computed_TSVI(albedo_data, lst_data)
    # computed_TSVI(evi_data, lst_data)
    # computed_TSVI(evi_data, albedo_data)